package org.bdware.analysis.example;

import org.bdware.analysis.BasicBlock;
import org.bdware.analysis.BreadthFirstSearch;
import org.bdware.analysis.OpInfo;
import org.bdware.analysis.taint.*;
import org.objectweb.asm.tree.AbstractInsnNode;
import org.objectweb.asm.tree.InsnNode;
import org.objectweb.asm.tree.JumpInsnNode;

import java.util.*;

// 1. Dependency Analysis ..
// 2. Multiple Source + Multiple Sink
// 3. Combine Dependency+TaintAnalysis
/**
 * Breadth-first forward taint analysis over a {@link TaintCFG}, plus a static
 * control-dependency analysis ({@link #depAnalysis(TaintCFG)}) on the same graph.
 *
 * <p>The constructor seeds the entry block (block 0) with an initial {@link TaintResult}
 * and registers it as the first block to analyse; the traversal itself is driven by the
 * {@link BreadthFirstSearch} superclass through {@link #execute(TaintBB)} and
 * {@link #getSuc(TaintBB)}.
 */
public class MultiSourceTaintAnalysis extends BreadthFirstSearch<TaintResult, TaintBB> {
    TaintCFG cfg;

    /**
     * Prepares the analysis of {@code cfg}, starting from its block 0.
     *
     * <p>Note that this also configures the shared {@code TaintResult.interpreter} and
     * {@code TaintResult.printer} with the taint bits and label order of {@code cfg}.
     *
     * @param cfg control-flow graph of the method to analyse
     */
    public MultiSourceTaintAnalysis(TaintCFG cfg) {
        this.cfg = cfg;
        TaintBB entry = (TaintBB) cfg.getBasicBlockAt(0);
        entry.preResult = new TaintResult();
        // Local variable slot 0 holds the root heap object.
        entry.preResult.frame.setLocal(0, HeapObject.getRootObject());
        // The result is not needed here; the call is kept (before the "arg" allocation, as
        // originally ordered) in case it has side effects on cfg -- TODO confirm and drop.
        cfg.argsLocal();
        // Local variable slot 1 holds the argument, tagged with a freshly allocated "arg" bit.
        entry.preResult.frame.setLocal(1, new TaintValue(1, cfg.taintBits.allocate("arg")));
        // Per the original note: finds the executeContract calls and allocates their taint
        // bits (behaviour of executeLocal() is not visible here).
        cfg.executeLocal();
        TaintResult.interpreter.setTaintBits(cfg.taintBits);
        entry.preResult.ret = new TaintValue(1);
        TaintResult.printer.setLabelOrder(cfg.getLabelOrder());

        List<TaintBB> toAnalysis = new ArrayList<>();
        toAnalysis.add(entry);
        entry.setInList(true);
        setToAnalysis(toAnalysis);

        if (TaintConfig.isDebug) {
            System.out.println("===Method:" + cfg.getMethodNode().name + cfg.getMethodNode().desc);
            System.out.println(
                    "===Local:"
                            + cfg.getMethodNode().maxLocals
                            + " "
                            + cfg.getMethodNode().maxStack);
        }
    }

    /**
     * Analyses a single block.
     *
     * @param t block to analyse
     * @return the result of {@code t.forwardAnalysis()}
     */
    @Override
    public TaintResult execute(TaintBB t) {
        return t.forwardAnalysis();
    }

    /**
     * Propagates the outgoing state of {@code t} to its successors.
     *
     * <p>For every successor block of {@code t} in the CFG, {@code t.sucResult} is merged
     * into the successor's {@code preResult}.
     *
     * @param t block whose successors are requested
     * @return the successor blocks of {@code t}
     */
    @Override
    public Collection<TaintBB> getSuc(TaintBB t) {
        Set<TaintBB> successors = new HashSet<>();
        for (BasicBlock bb : cfg.getSucBlocks(t)) {
            TaintBB next = (TaintBB) bb;
            if (TaintConfig.isDebug) {
                System.out.println(
                        "[MultiSoruceTaintAnalysis] B"
                                + next.blockID
                                + " beforeMerge:"
                                + next.preResult.frame2Str());
            }
            next.preResult.mergeResult(t.sucResult);
            if (TaintConfig.isDebug) {
                System.out.println(
                        "[MultiSoruceTaintAnalysis] B"
                                + next.blockID
                                + " afterMerge:"
                                + next.preResult.frame2Str());
            }
            successors.add(next);
        }
        return successors;
    }

    /**
     * Dependency analysis: determines, for each block containing a return instruction,
     * which branch blocks it depends on.
     *
     * <p>A branch block is a block with more than one successor that contains a
     * {@link JumpInsnNode}. A block depends on a branch block when it is reachable from
     * at least one, but not all, of that branch block's successors.
     *
     * <p>In addition, the key {@code cfg.getBasicBlockSize() - 1} is always mapped to the
     * de-duplicated union of the dependencies of all return blocks (replacing that block's
     * own entry if it had one).
     *
     * @param cfg control-flow graph to analyse
     * @return map from block ID to the IDs of the branch blocks it depends on
     */
    public static Map<Integer, List<Integer>> depAnalysis(TaintCFG cfg) {
        // Step 1: find all the conditional branch blocks.
        Set<Integer> branchBlockIds = findBranchBlocks(cfg);

        // The successors of a branch block do not depend on the block being tested, so look
        // them up once. LinkedHashMap keeps the iteration order of branchBlockIds.
        Map<Integer, Set<BasicBlock>> branchSuccessors = new LinkedHashMap<>();
        for (int id : branchBlockIds) {
            branchSuccessors.put(id, cfg.getSucBlocks(cfg.getBasicBlockAt(id)));
        }

        // Index the blocks by ID.
        Map<Integer, BasicBlock> blocksById = new HashMap<>();
        for (BasicBlock bb : cfg.getBlocks()) {
            blocksById.put(bb.blockID, bb);
        }

        // Step 2: compute the dependencies, but only for blocks that can return -- those
        // are the only ones reported, so reachability searches for other blocks are skipped.
        Map<Integer, List<Integer>> returnMap = new HashMap<>();
        for (Map.Entry<Integer, BasicBlock> entry : blocksById.entrySet()) {
            if (containsReturn(cfg.getBasicBlockAt(entry.getKey()))) {
                returnMap.put(
                        entry.getKey(),
                        findDependencies(cfg, branchSuccessors, entry.getValue()));
            }
        }

        // Step 3: attach the union of all dependencies to the last block.
        // LinkedHashSet removes duplicates while keeping first-seen order.
        Set<Integer> lastBlockDep = new LinkedHashSet<>();
        for (List<Integer> deps : returnMap.values()) {
            lastBlockDep.addAll(deps);
        }
        returnMap.put(cfg.getBasicBlockSize() - 1, new ArrayList<>(lastBlockDep));

        // TODO get value from branch block
        return returnMap;
    }

    /**
     * Collects the IDs of blocks that have more than one successor and contain a jump
     * instruction. The successor-count test filters out unconditional jumps such as GOTO.
     */
    private static Set<Integer> findBranchBlocks(TaintCFG cfg) {
        Set<Integer> ids = new HashSet<>();
        for (BasicBlock b : cfg.getBlocks()) {
            if (cfg.getSucBlocks(b).size() > 1 && containsJump(b)) {
                ids.add(b.blockID);
            }
        }
        return ids;
    }

    /** Returns whether {@code block} contains at least one {@link JumpInsnNode}. */
    private static boolean containsJump(BasicBlock block) {
        for (AbstractInsnNode insn : block.getInsn()) {
            if (insn instanceof JumpInsnNode) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns whether {@code block} contains an {@link InsnNode} whose opcode is flagged
     * by {@code OpInfo} as able to return.
     */
    private static boolean containsReturn(BasicBlock block) {
        for (AbstractInsnNode insn : block.getInsn()) {
            if (insn instanceof InsnNode && OpInfo.ops[insn.getOpcode()].canReturn()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Lists the branch blocks that {@code block} depends on: those for which
     * {@code block} is reachable from some, but not all, of the branch's successors.
     */
    private static List<Integer> findDependencies(
            TaintCFG cfg, Map<Integer, Set<BasicBlock>> branchSuccessors, BasicBlock block) {
        List<Integer> dependencies = new ArrayList<>();
        for (Map.Entry<Integer, Set<BasicBlock>> branch : branchSuccessors.entrySet()) {
            Set<BasicBlock> successors = branch.getValue();
            int reaching = 0;
            for (BasicBlock successor : successors) {
                if (isArrival(cfg, successor, block)) {
                    reaching++;
                }
            }
            if (reaching > 0 && reaching != successors.size()) {
                dependencies.add(branch.getKey());
            }
        }
        return dependencies;
    }

    /**
     * Tests whether {@code bb} can be reached from {@code suc}.
     *
     * <p>A block is always considered reachable from itself; otherwise the answer is
     * delegated to a {@code DirectGraphDFS} started at {@code suc}.
     *
     * @param cfg control-flow graph containing both blocks
     * @param suc block to start from
     * @param bb block to reach
     * @return {@code true} if {@code bb} is reachable from {@code suc}
     */
    public static boolean isArrival(TaintCFG cfg, BasicBlock suc, BasicBlock bb) {
        if (suc.blockID == bb.blockID) {
            return true;
        }
        return new DirectGraphDFS(cfg, suc).isArrival(bb);
    }
}
